[TRTLLM-12347][feat] enable VSA in VisualGen#14280
Conversation
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #49010 [ run ] triggered by Bot. Commit:
PR_Github #49010 [ run ] completed with state
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #49483 [ run ] triggered by Bot. Commit:
PR_Github #49483 [ run ] completed with state
Suggested restructuring: reuse the
Hi, @o-stoner! I checked in with @zhenhuaw-me today, and we'd like to coordinate the VSA integration so that it composes with the CuTe-DSL backend that #13721 is about to land. Below is a concrete restructuring proposal; happy to discuss alternatives if any of these don't fit your kernel's constraints.
Context (what PR #13721 brings):
Requested changes for #14280:
Let me know what you think — or if any of these conflict with constraints I'm missing (e.g., kernel availability, config, etc.). |
9cc9858 to
22b2f5d
Compare
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #51445 [ run ] triggered by Bot. Commit:
PR_Github #51445 [ run ] completed with state
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #51646 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This pull request adds comprehensive Video Sparse Attention (VSA) support to TensorRT-LLM's visual generation framework for Blackwell GPUs. It includes a new CUTE DSL persistent kernel with custom scheduler and PTX primitives, integration into CuTeDSLAttention and distributed attention backends, Wan pipeline orchestration with per-step metadata building, and extensive test coverage validating correctness, equivalence, performance, and multi-GPU distributed execution.
Changes
VSA Configuration and Type System
CuTe DSL Persistent Kernel Implementation
Backend and Module Integration
Wan Pipeline VSA Orchestration
VSA Test Coverage
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~60 minutes
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
Actionable comments posted: 10
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py`:
- Around line 201-216: The module-global _vsa_forward_context must be replaced
with a request-local contextvar to avoid cross-request clobbering: create a
contextvars.ContextVar[Optional[VSAMetadata]] (e.g. _vsa_forward_context_var)
and update set_vsa_forward_context to set the ContextVar and yield while
storing/resetting the returned token on exit, and update get_vsa_forward_context
to return _vsa_forward_context_var.get(None); keep the same function/class names
(set_vsa_forward_context, get_vsa_forward_context, VSAMetadata,
_vsa_forward_context -> _vsa_forward_context_var) so callers don’t change.
- Around line 527-541: The CuTe branch currently asserts when num_cubes exceeds
VSA_KERNEL_MAX_CUBES; instead modify the gating so the code falls back to dense
SDPA: include the condition num_cubes <= VSA_KERNEL_MAX_CUBES in the computation
of use_cute (the boolean used to choose the CuTe kernel), and remove or replace
the subsequent assert in the CuTe branch (the block referencing
VSA_KERNEL_MAX_CUBES and num_cubes) so oversized inputs simply skip CuTe and use
the existing dense fallback.
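
A minimal sketch of the two `cute_dsl.py` suggestions above, not the PR's actual code. It keeps the names quoted in the comments (`set_vsa_forward_context`, `get_vsa_forward_context`, `_vsa_forward_context_var`, `VSA_KERNEL_MAX_CUBES`), but the `VSAMetadata` stand-in with its single field, the placeholder limit value, and the `should_use_cute` helper are assumptions made only so the example runs on its own.

```python
import contextlib
import contextvars
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class VSAMetadata:
    """Stand-in for the real metadata class; the field is hypothetical."""

    num_cubes: int


# Context-local storage: a value set in one thread or asyncio task is not
# visible to other threads or to tasks that were created earlier.
_vsa_forward_context_var = contextvars.ContextVar("vsa_forward_context", default=None)


@contextlib.contextmanager
def set_vsa_forward_context(metadata: Optional[VSAMetadata]) -> Iterator[None]:
    token = _vsa_forward_context_var.set(metadata)
    try:
        yield
    finally:
        # Restore whatever was visible before entering, even on error.
        _vsa_forward_context_var.reset(token)


def get_vsa_forward_context() -> Optional[VSAMetadata]:
    return _vsa_forward_context_var.get()


# Placeholder value for illustration, not the real kernel limit.
VSA_KERNEL_MAX_CUBES = 1024


def should_use_cute(kernel_available: bool, num_cubes: int) -> bool:
    """Fold the cube limit into the gate so oversized inputs skip CuTe."""
    return kernel_available and num_cubes <= VSA_KERNEL_MAX_CUBES
```

With the limit folded into the gate, an input that exceeds `VSA_KERNEL_MAX_CUBES` takes the same dense path as an input for which the kernel is unavailable, so no assert is needed in the CuTe branch.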
In
`@tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py`:
- Line 416: The hardcoded limit self.max_indices = 4 * 1024 can be exceeded by
variable_block_sizes, causing a shared-memory overflow when copying into
sVariable_block_sizes; add a runtime validation that
variable_block_sizes.shape[0] <= self.max_indices before the copy (or
assert/raise a clear error) and fail fast with a descriptive message, and/or
enforce the check earlier in block_sparse_attn_from_indices_cute in interface.py
so callers cannot pass larger arrays; update any related docs/comments to state
the max_indices constraint and reference the symbols max_indices,
sVariable_block_sizes, variable_block_sizes, and
block_sparse_attn_from_indices_cute when making the change.
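
The fail-fast check described above could look roughly like the following sketch. The helper name `check_variable_block_sizes` is hypothetical, and it takes a plain entry count (for example `variable_block_sizes.shape[0]`) so that it runs without a tensor library; only the `4 * 1024` limit and the quoted symbol names come from the comment.

```python
# Mirrors the hardcoded self.max_indices = 4 * 1024 quoted in the comment.
MAX_INDICES = 4 * 1024


def check_variable_block_sizes(num_entries: int, max_indices: int = MAX_INDICES) -> None:
    """Raise before launch if the array cannot fit in the shared-memory staging buffer."""
    if num_entries > max_indices:
        raise ValueError(
            f"variable_block_sizes has {num_entries} entries, but the kernel "
            f"stages at most max_indices={max_indices} entries into "
            "sVariable_block_sizes"
        )
```

Calling this at the top of `block_sparse_attn_from_indices_cute` would surface the limit as a Python exception rather than a shared-memory overflow inside the kernel.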
In
`@tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/ptx.py`:
- Around line 144-160: The inline assembly in the else branch currently emits a
shared-scope atomic ("atom.relaxed.shared::cta.cta.max.s32") but this path
targets global memory; update the asm string in the llvm.inline_asm call to use
the global scope ("atom.relaxed.global::cta.cta.max.s32") while keeping the same
operand ($0) and constraints, i.e., modify the asm literal passed to
llvm.inline_asm (the triple-quoted string) to replace "shared" with "global" so
the global-memory atomic is emitted.
In `@tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py`:
- Around line 542-555: The code mutates the shared scheduler config when
applying a user flow_shift override (variables: flow_shift, sched_cfg,
shift_key, self.scheduler.register_to_config), which makes the change persist
across requests; instead, apply the override only request-scoped by either
restoring the original sched_cfg[shift_key] after the request or by creating a
request-local copy of the scheduler/config before calling set_timesteps();
specifically, capture the original value (orig_shift =
sched_cfg.get(shift_key)), call register_to_config only on a
cloned/configured-local scheduler or restore orig_shift via
register_to_config(**{shift_key: orig_shift}) after completing the request so
the shared scheduler config is not permanently mutated.
In `@tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py`:
- Around line 489-503: The current code calls
self.scheduler.register_to_config(...) to apply a per-request flow_shift, which
mutates the shared scheduler and leaks that override to subsequent requests;
instead avoid mutating the shared scheduler by creating a request-local
scheduler/config or restoring the original value after use: either clone the
scheduler or its config (e.g., copy sched_cfg = dict(self.scheduler.config) and
apply the flow_shift to that local config or instantiate a shallow copy of the
scheduler) and use that local scheduler/config before calling set_timesteps(),
or if you must modify self.scheduler temporarily, capture the original
sched_cfg[shift_key] first and restore it immediately after the request
completes; reference flow_shift, sched_cfg, self.scheduler, register_to_config,
and set_timesteps when making the change.
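
Both Wan pipeline findings above describe the same capture-and-restore pattern, which can be sketched as a context manager. This is an illustration under assumptions: `FakeScheduler` is a hypothetical stand-in that exposes only `config` and `register_to_config`, `temporary_flow_shift` is a made-up helper name, and the `"shift"` key used in the usage note is illustrative.

```python
import contextlib
from typing import Any, Dict, Iterator, Optional


class FakeScheduler:
    """Hypothetical stand-in exposing only config and register_to_config."""

    def __init__(self, **config: Any) -> None:
        self.config: Dict[str, Any] = dict(config)

    def register_to_config(self, **kwargs: Any) -> None:
        self.config.update(kwargs)


@contextlib.contextmanager
def temporary_flow_shift(
    scheduler: Any, shift_key: str, flow_shift: Optional[float]
) -> Iterator[None]:
    """Apply a per-request flow_shift and restore the original value on exit."""
    if flow_shift is None:
        # No override requested: leave the shared config untouched.
        yield
        return
    orig_shift = scheduler.config.get(shift_key)
    scheduler.register_to_config(**{shift_key: flow_shift})
    try:
        yield
    finally:
        # Runs on success and on error, so the override never leaks.
        scheduler.register_to_config(**{shift_key: orig_shift})
```

A pipeline would wrap its `set_timesteps()` call and denoising loop in `with temporary_flow_shift(self.scheduler, shift_key, flow_shift):`. Note that restoring a shared object is still unsafe if two requests run concurrently on the same scheduler; in that case the request-local copy mentioned in the comments is the safer option.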
In `@tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py`:
- Around line 369-390: The new VSA gate projections to_gate_compress and
to_gate_fine are created as full-width dense linears on every rank, which
misaligns with attn1's TP-local Q shards and duplicates work when tp_size>1;
change their construction to the same column-parallel/sharded setup used for the
Q projection (i.e., mirror the Q Linear creation parameters: use the same
mapping/partitioning, quant_config, skip_create_weights_in_init,
force_dynamic_quantization, and out-dim q_dim) so each rank only holds its
TP-local slice and the gate tensors line up with attn1's local Q shard. Ensure
you reference and reuse the same sharding/mapping pattern used when creating the
Q projection to_gate (or whichever variable constructs Q) so topology and sizes
match across ranks.
In `@tensorrt_llm/_torch/visual_gen/modules/attention.py`:
- Around line 471-475: The _reshape_gate helper reshapes gate tensors using the
global self.num_attention_heads which desyncs under tensor-parallelism; update
_reshape_gate (used for gate_compress / gate_fine) to compute the head count
from the incoming gate tensor (or use the same local head count used when
reshaping q/k/v) instead of self.num_attention_heads, then apply view/transpose
logic with that derived local_head_count so the final layout matches the
attention tensors and respects backend_layout (AttentionTensorLayout.HND)
handling.
In `@tests/unittest/_torch/visual_gen/test_attention_integration.py`:
- Around line 620-628: After constructing integrated (Attention(...,
config=cfg_vsa)), add an explicit assertion that the VSA path was chosen by
invoking the internal selector or flag (call integrated._build_vsa_setup() or
inspect any backend attribute set by that method) and assert it indicates
CUTEDSL/VSA; e.g., ensure the result/attribute equals the expected VSA backend
before proceeding to use integrated in the test so the test fails if CUTEDSL
silently falls back to dense.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: da759100-77d8-4987-90f4-527b2030545e
📒 Files selected for processing (26)
- tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py
- tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py
- tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py
- tensorrt_llm/_torch/visual_gen/attention_backend/utils.py
- tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/__init__.py
- tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py
- tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py
- tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/ptx.py
- tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/scheduler.py
- tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
- tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
- tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
- tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
- tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
- tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
- tensorrt_llm/_torch/visual_gen/modules/attention.py
- tensorrt_llm/_torch/visual_gen/pipeline_loader.py
- tensorrt_llm/visual_gen/__init__.py
- tensorrt_llm/visual_gen/args.py
- tensorrt_llm/visual_gen/params.py
- tensorrt_llm/visual_gen/sparse_attention.py
- tests/integration/test_lists/test-db/l0_b200.yml
- tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py
- tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py
- tests/unittest/_torch/visual_gen/test_attention_integration.py
- tests/unittest/_torch/visual_gen/test_attention_perf.py
PR_Github #51646 [ run ] completed with state
6fe39d4 to
6079294
Compare
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #51898 [ run ] triggered by Bot. Commit:
PR_Github #51898 [ run ] completed with state
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #54358 [ run ] triggered by Bot. Commit:
PR_Github #54358 [ run ] completed with state
Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #54665 [ run ] triggered by Bot. Commit:
PR_Github #54665 [ run ] completed with state
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #55050 [ run ] triggered by Bot. Commit:
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #55060 [ run ] triggered by Bot. Commit:
PR_Github #55050 [ run ] completed with state
chang-l
left a comment
CI coverage for test_wan_vsa_ulysses.py (8-GPU, cfg=2 × ulysses=4) — please confirm it actually runs before merge.
The test is collected via the unittest/_torch/visual_gen/multi_gpu directory entry in l0_dgx_b200.yml, which lives under the system_gpu_count: 8 / stage: post_merge / backend: pytorch condition. Two problems:
--add-multi-gpu-test only adds pre-merge multi-GPU stages, so it will not trigger this. Post-merge tests need /bot run --stage-list "" (or the heavy /bot run --post-merge).
Could you run python scripts/test_to_stage_mapping.py --tests "test_wan_vsa_ulysses" on this branch and confirm which stage runs it, then trigger that stage (e.g. /bot run --stage-list "") and verify it passes before merge?
PR_Github #55060 [ run ] completed with state
Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>
3730e37 to
56a9400
Compare
Signed-off-by: o-stoner <245287810+o-stoner@users.noreply.github.com>
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #55315 [ run ] triggered by Bot. Commit:
@chang-l
PR_Github #55315 [ run ] completed with state
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #55536 [ run ] triggered by Bot. Commit:
PR_Github #55536 [ run ] completed with state
Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>
/bot run --disable-fail-fast --add-multi-gpu-test
PR_Github #55598 [ run ] triggered by Bot. Commit:
PR_Github #55598 [ run ] completed with state
Summary by CodeRabbit
New Features
- `flow_shift` parameter to override the scheduler configuration in Wan pipelines during inference. This allows an apples-to-apples quality comparison with FastVideo, whose Wan pipelines use different `flow_shift` values than the scheduler default.
- `VideoSparseAttentionConfig` for controlling VSA sparsity levels.
Enhancements
Description
Adds VSA attention backend for TRT-LLM VisualGen based on the following VSA paper. Integrates the B200 CuteDSL kernel here. Currently, this backend is supported for Wan 2.1 using the following fine-tuned model from FastVideo. This support will be extended to Wan 2.2 T2V 14B / TI2V 5B once ModelOpt fine-tuned weights are ready.
Quality and performance findings are summarized on the page here, and quality compared against FastVideo on H200, using the same input noisy latent and `flow_shift` value, is summarized here.
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added: either `api-compatible` or `api-breaking`. For `api-breaking`, include `BREAKING` in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.